/**
  ******************************************************************************
  * @file    main.cc
  * @author  xinkai yu
  * @version V1.0.1
  * @date    2022/02/18
  * @brief   yolo & depth camera based detector for MobiRo gen3 @ TIB330
  ******************************************************************************
  * @attention
  *
  ******************************************************************************
  */

#include <atomic>
#include <chrono>
#include <cmath>
#include <fstream>
#include <iostream>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <vector>

#include <sys/time.h>

#include <Eigen/Core>
#include <Eigen/Dense>
#include <Eigen/Geometry>

#include <opencv2/opencv.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>

#include "ros/ros.h"
#include "std_srvs/SetBool.h"
#include "sensor_msgs/Image.h"
#include "cv_bridge/cv_bridge.h"
#include "image_transport/image_transport.h"
#include "message_filters/subscriber.h"
#include "message_filters/time_synchronizer.h"
#include "message_filters/sync_policies/approximate_time.h"
#include "tf/transform_listener.h"
#include "tf_conversions/tf_eigen.h"
#include "jsk_recognition_msgs/BoundingBox.h"
#include "std_msgs/Float64MultiArray.h" 

#define OPENCV
#include "yolo_v2_class.hpp"

using namespace Eigen;

// One synchronized RGB + depth capture from the ZED camera, queued in
// FrameBuffer by the camera callback and consumed by the detect thread.
struct Frame {
  cv_bridge::CvImagePtr rgb_image;  // BGR8 left image fed to the YOLO detector
  cv_bridge::CvImagePtr depth_map;  // 32FC1 registered depth map used for range lookup
  double time_stamp;                // seconds (not ms: built via toSec()) since node init;
                                    // used by FrameBuffer to skip already-consumed frames
};

// Fixed-size single-producer / single-consumer ring buffer of Frames.
// Every slot carries its own std::timed_mutex so that both Push (camera
// callback thread) and GetLatest (detect thread) give up after 2 ms
// instead of blocking each other.
class FrameBuffer {
 public:
  FrameBuffer(size_t size): frames_(size), mutexs_(size) {
    front_idx_ = 0;
    rear_idx_ = 0;
    last_get_timestamp_ = 0.0;
  }
  ~FrameBuffer() = default;

  // Writes `frame` into the slot after the current front. Returns false
  // if that slot's lock cannot be acquired within 2 ms (the consumer is
  // still reading it); the frame is then dropped rather than blocking
  // the camera callback.
  bool Push(const Frame& frame) {
    // slots advance 1, 2, ..., size - 1, 0
    const size_t new_front_idx =
        (front_idx_.load(std::memory_order_relaxed) + 1) % frames_.size();
    std::unique_lock<std::timed_mutex> lock(mutexs_[new_front_idx],
                                            std::chrono::milliseconds(2));
    // try for 2ms to lock, try not to be blocked
    if (!lock.owns_lock()) return false;
    frames_[new_front_idx] = frame;
    // buffer full: logically drop the oldest frame at the rear
    if (new_front_idx == rear_idx_) {
      rear_idx_ = (rear_idx_ + 1) % frames_.size();
    }
    // release-store publishes the fully written slot to the consumer
    front_idx_.store(new_front_idx, std::memory_order_release);
    return true;
  }

  // Copies the newest frame into `frame`. Returns false when the slot is
  // locked, when the frame was already consumed (same timestamp as the
  // previous successful call), or when it holds no image yet.
  bool GetLatest(Frame& frame) {
    // acquire-load pairs with the release-store in Push. The previous
    // code read front_idx_ through `volatile`, which provides no
    // inter-thread synchronization in C++ and was a data race.
    const size_t front_idx = front_idx_.load(std::memory_order_acquire);
    // try for 2ms to lock
    std::unique_lock<std::timed_mutex> lock(mutexs_[front_idx],
                                            std::chrono::milliseconds(2));
    if (!lock.owns_lock() ||
        frames_[front_idx].time_stamp == last_get_timestamp_ ||
        frames_[front_idx].rgb_image.get() == nullptr) {
      return false;  // skipping empty/duplicate frames avoids blinking output
    }

    frame = frames_[front_idx];
    last_get_timestamp_ = frames_[front_idx].time_stamp;
    return true;
  }

 private:
  std::vector<Frame> frames_;
  std::vector<std::timed_mutex> mutexs_;

  // Written by the producer thread, read by the consumer thread.
  std::atomic<size_t> front_idx_;
  size_t rear_idx_;            // producer-only bookkeeping of the oldest slot
  double last_get_timestamp_;  // consumer-only: timestamp last handed out
};

// YOLO + ZED-depth based grasp-target detector.
// Subscribes to approximately time-synchronized RGB and depth images,
// buffers them in FrameBuffer, runs YOLO detection in a dedicated thread,
// and publishes: the annotated image, the target bounding box (camera
// frame), and an axis-aligned workspace box around the target.
class ArmDetector {
 public:
  typedef message_filters::sync_policies::ApproximateTime< \
      sensor_msgs::Image, sensor_msgs::Image> sync_policy;

 public:
  // Loads the YOLO model, wires up all ROS I/O, and starts the detect thread.
  explicit ArmDetector() {
    this->DetectorInit();
    it_ = std::make_unique<image_transport::ImageTransport>(nh_);
    camera_rgb_sub_.subscribe(nh_, "/zedm/zed_node/left/image_rect_color", 5);
    camera_depth_sub_.subscribe(
        nh_, "/zedm/zed_node/depth/depth_registered", 5);
    sync_.connectInput(camera_rgb_sub_, camera_depth_sub_);
    sync_.registerCallback(
        boost::bind(&ArmDetector::CameraRgbdCallback, this, _1, _2));

    switch_service_ = nh_.advertiseService("arm_detector/set_detect_status", 
                                           &ArmDetector::DetectSwitchCallback, 
                                           this);  // service toggling whether detection runs

    detect_pub_ = it_->advertise("/arm_detect_result", 5);
    // publishes the selected target's bounding box for visualization
    target_visual_pub_ = nh_.advertise<jsk_recognition_msgs::BoundingBox>(
        "grasp_target", 1);
    // publisher for the target pose transformed into base_link (disabled)
    // target_pose_pub_ = nh_.advertise<geometry_msgs::TransformStamped>(
    //     "target_pose", 1);

    workspace_pub_=nh_.advertise<std_msgs::Float64MultiArray>("workspace",1);

    init_time_ = ros::Time::now();
    frame_buffer_ = std::make_unique<FrameBuffer>(10);
    detect_enable_flag_ = true;
    // enabled by default: nothing runs until frames actually arrive
    // NOTE(review): detect_enable_flag_ is a plain bool written by the ROS
    // service thread and read by detect_thread_ — consider std::atomic<bool>.
    detect_thread_ = std::thread(&ArmDetector::DetectRealize, this);
    // visual_thread_ = std::thread(&ArmDetector::ImageVisualize, this);
  }

  // Joins the worker thread; DetectRealize exits once ros::ok() is false.
  ~ArmDetector() {
    if (detect_thread_.joinable()) detect_thread_.join();
    // if (visual_thread_.joinable()) visual_thread_.join();
  }

  void ImageVisualize() {  // for debug only 
    ros::Rate loop_rate(100);
    while (ros::ok()) {
      loop_rate.sleep();
    }
  }

  // Detect-thread main loop (up to 100 Hz): pulls the newest frame, runs
  // YOLO, looks up depth at the last matching detection's center, and if
  // the target is closer than 0.8 m publishes the annotated image, the
  // target box, and the workspace limits.
  void DetectRealize() {
    ros::Rate loop_rate(100);
    Frame detect_frame;
    while (ros::ok()) {
      if (!detect_enable_flag_ || !frame_buffer_->GetLatest(detect_frame)) {
        loop_rate.sleep();
        continue;
      }  // skip straight to sleeping to save compute when disabled/no frame
        
      cv_bridge::CvImagePtr out_ptr = detect_frame.rgb_image;
      if (out_ptr->image.empty()) {
        loop_rate.sleep();
        continue;
      }

      auto detect_outs = detector->detect(out_ptr->image, 0.2);
      // center defaults to (639.41, 356.83) — presumably the camera
      // principal point (cx, cy); TODO confirm against ZED calibration.
      cv::Point center(639.41, 356.83);
      float depth_value = 999.9;  // sentinel: fails the < 0.8 m gate below
      for (auto it = detect_outs.begin(); it != detect_outs.end(); it++) {
        if (it->obj_id != 0) continue;  // only class 0 (the can) is a target
        center.x = it->x + it->w / 2;
        center.y = it->y + it->h / 2;
        // NOTE(review): no bounds check — a box near the image edge could
        // index outside the depth map; confirm detector clamps boxes.
        depth_value = detect_frame.depth_map->image.at<float>(
            center.y, center.x);
        if (std::isnan(depth_value) || std::isinf(depth_value)) continue;
        cv::rectangle(out_ptr->image, cv::Point(it->x, it->y), 
                      cv::Point(it->x + it->w, it->y + it->h), 
                      cv::Scalar(0, 0, 255), 2);  // box all detected target
      }
      // only the LAST class-0 detection's center/depth survives the loop
      if (depth_value<0.8 && !std::isnan(depth_value) && 
          !std::isinf(depth_value)) {
        detect_pub_.publish(out_ptr->toImageMsg());
        // ROS_INFO("detect done! z: %lf", depth_value);
        jsk_recognition_msgs::BoundingBox target_box;
        target_box.header.frame_id = "zedm_left_camera_frame";
        // frame time_stamp is relative to init_time_; restore absolute time
        target_box.header.stamp = 
          ros::Time().fromSec(detect_frame.time_stamp + init_time_.toSec());
        target_box.dimensions.x = 0.07;
        target_box.dimensions.y = 0.07;
        target_box.dimensions.z = 0.12;

        // Pinhole back-projection: x forward = depth; 699.55 is presumably
        // the focal length in pixels — TODO confirm against calibration.
        target_box.pose.position.x = depth_value;
        target_box.pose.position.y = 
          depth_value * (639.41 - center.x) / 699.55;
        target_box.pose.position.z = 
          depth_value * (356.83 - center.y) / 699.55;
        this->ArmInverseKine(target_box);
        // axis-aligned workspace around the target: [x_min, x_max,
        // y_min, y_max, z_min, z_max] in the camera frame
        std_msgs::Float64MultiArray workspace;
        workspace.data.resize(6);
        workspace.data[0]=target_box.pose.position.x-0.1;
        workspace.data[1]=target_box.pose.position.x+0.06;
        workspace.data[2]=target_box.pose.position.y-0.06;
        workspace.data[3]=target_box.pose.position.y+0.06;
        workspace.data[4]=target_box.pose.position.z-0.09;
        workspace.data[5]=target_box.pose.position.z+0.09;
        workspace_pub_.publish(workspace);
        target_visual_pub_.publish(target_box);
      }
      loop_rate.sleep();
    }
  }

 private:
  // Synchronized RGB+depth callback: converts both messages to OpenCV
  // (BGR8 / 32FC1) and pushes them into the ring buffer with a timestamp
  // relative to init_time_ (seconds).
  void CameraRgbdCallback(const sensor_msgs::Image::ConstPtr &image,
                          const sensor_msgs::Image::ConstPtr &depth) {
    cv_bridge::CvImagePtr rgb_ptr, depth_ptr;
    try {
      depth_ptr = cv_bridge::toCvCopy(
        depth, sensor_msgs::image_encodings::TYPE_32FC1);
      rgb_ptr = cv_bridge::toCvCopy(image, sensor_msgs::image_encodings::BGR8);
    }
    catch (cv_bridge::Exception& e) {
      ROS_ERROR("cv_bridge exception: %s", e.what());
      return;
    }
    double time_stamp = (ros::Time::now() - init_time_).toSec();
    frame_buffer_->Push({rgb_ptr, depth_ptr, time_stamp});
    // ROS_INFO("Rgbd msg got!");                       
  }

  // std_srvs/SetBool service handler: enables/disables the detect loop.
  bool DetectSwitchCallback(std_srvs::SetBoolRequest &request, 
                            std_srvs::SetBoolResponse &response) {
    detect_enable_flag_ = request.data;  // the service decides whether detection runs
    response.success = true;
    if (detect_enable_flag_) response.message = "Detector enable!";
    else response.message = "Detector disable!";
    return true;
  }

  // Loads class names and YOLO config/weights from a hard-coded workspace
  // path and constructs the darknet Detector on GPU 0.
  // NOTE(review): absolute path ties the node to one machine — consider a
  // ROS parameter.
  void DetectorInit() {
    std::string ws_root = "/home/k331/projects/catkin_grasp";
    std::string classes_file = ws_root + "/yolo/obj.names";
    std::string model_config = ws_root + "/yolo/yolo-obj.cfg";
    std::string model_weights = ws_root + "/yolo/yolo-obj_last.weights";

    std::ifstream ifs(classes_file.c_str());
    std::string line;
    while (getline(ifs, line)) classes_.push_back(line);
    detector = std::make_unique<Detector>(model_config, model_weights, 0);
    std::cout << "Yolo init success!" << std::endl;
  }

  // Transforms the camera-frame target into base_link coordinates and sets
  // the box orientation to the inverse camera rotation. The transformed
  // pose is computed but currently only the orientation written into
  // `target_box` has an external effect (pose publishing is disabled).
  void ArmInverseKine(jsk_recognition_msgs::BoundingBox &target_box) {
    tf::StampedTransform transform;
    try {
      tf_listener_.lookupTransform("/base_link", "/zedm_left_camera_frame", 
                                   ros::Time(0), transform);  // Time(0) = latest available
    }
    catch (tf::TransformException &ex) {
      ROS_ERROR("%s",ex.what());
      ros::Duration(1.0).sleep();
      return;
    }

    tf::Quaternion inv_quat = transform.getRotation().inverse();
    // inverse rotation expresses the box orientation relative to the camera frame
    tf::quaternionTFToMsg(inv_quat, target_box.pose.orientation);
    Vector3d target_position_eye(target_box.pose.position.x,
                                 target_box.pose.position.y,
                                 target_box.pose.position.z);
    
    Quaterniond eigen_quat;  // should be Eigen::Quaterniond
    tf::quaternionTFToEigen(transform.getRotation(), eigen_quat);
    Vector3d eye_trans;
    tf::vectorTFToEigen(transform.getOrigin(), eye_trans);
    Vector3d target_position_base = eigen_quat.matrix() * target_position_eye + 
                                    eye_trans;  // trans target into base
    geometry_msgs::TransformStamped target_pose;
    target_pose.header.frame_id = "base_link";
    target_pose.header.stamp = target_box.header.stamp;
    target_pose.child_frame_id = "target_link";
    target_pose.transform.translation.x = target_position_base(0);
    target_pose.transform.translation.y = target_position_base(1);
    target_pose.transform.translation.z = target_position_base(2);
    // target_pose_pub_.publish(target_pose);//
  }

  ros::NodeHandle nh_;
  std::unique_ptr<image_transport::ImageTransport> it_;
  message_filters::Subscriber<sensor_msgs::Image> camera_rgb_sub_;
  message_filters::Subscriber<sensor_msgs::Image> camera_depth_sub_;
  message_filters::Synchronizer<sync_policy> sync_{sync_policy(10)};
  image_transport::Publisher detect_pub_;     // annotated detection image
  ros::Publisher target_visual_pub_;          // jsk BoundingBox of the target
  ros::Publisher target_pose_pub_;            // unused (pose publishing disabled)
  ros::Publisher workspace_pub_;              // 6-element workspace limits
  ros::ServiceServer switch_service_;
  tf::TransformListener tf_listener_;
  std::thread detect_thread_;
  std::thread visual_thread_;
  ros::Time init_time_;                       // reference epoch for Frame::time_stamp
  std::unique_ptr<FrameBuffer> frame_buffer_;
  std::unique_ptr<Detector> detector;  // yolo based detector
  std::vector<std::string> classes_;   // class names loaded from obj.names
  bool detect_enable_flag_;            // toggled by the SetBool service
};

// Entry point: register the node, construct the detector (which spawns
// its own worker thread), then hand control to the ROS event loop.
int main(int argc, char **argv) {
  ros::init(argc, argv, "arm_detect_node");
  ArmDetector detector_node;
  ros::spin();  // blocks until shutdown; callbacks run from here
  return 0;
}